package com.zlsy.es.dianping.service.impl;

import cn.hutool.core.util.StrUtil;
import com.baidu.unbiz.easymapper.MapperFactory;
import com.zlsy.es.dianping.pojo.model.Shop;
import com.zlsy.es.dianping.pojo.vo.response.ShopResponse;
import com.zlsy.es.dianping.service.EsService;
import com.zlsy.es.dianping.service.ShopService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.common.document.DocumentField;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.functionscore.FieldValueFactorFunctionBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.GaussDecayFunctionBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.sort.SortOrder;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.*;

/**
 * @author zhouliang
 * @date 2020/03/10
 **/
@Slf4j
@Service
@RequiredArgsConstructor(onConstructor = @__(@Autowired))
public class EsServiceImpl implements EsService {

    private final RestHighLevelClient restHighLevelClient;
    private final ShopService shopService;

    @Override
    public Map<String, Object> searchES(BigDecimal longitude, BigDecimal latitude, String keyword, Integer orderBy, Integer categoryId, String tags) {

        Map<String, Object> result = new HashMap<>(16);

        SearchRequest searchRequest = new SearchRequest("shop");
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();

        //function_score

        //1.1、构建bool query
        BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery();
        boolQueryBuilder.must(QueryBuilders.termQuery("seller_disabled_flag", 0));
        //关键字
        if (StrUtil.isNotEmpty(keyword)) {
            boolQueryBuilder.must(QueryBuilders.matchQuery("name", keyword).boost(0.1f));
        }
        if (null != categoryId) {
            boolQueryBuilder.must(QueryBuilders.termQuery("category_id", categoryId));
        }
        if (StrUtil.isNotEmpty(tags)) {
            boolQueryBuilder.must(QueryBuilders.termQuery("tags", tags));
        }
        //默认排序
        if (null == orderBy) {
            FunctionScoreQueryBuilder.FilterFunctionBuilder[] filterFunctionBuilders = new FunctionScoreQueryBuilder.FilterFunctionBuilder[3];
            //2.1、构建高斯函数
            String origin = longitude + "," + latitude;
            ScoreFunctionBuilder<GaussDecayFunctionBuilder> gaussDecayFunctionBuilder = new GaussDecayFunctionBuilder("location", origin, "100km", "0km", 0.5);
            gaussDecayFunctionBuilder.setWeight(9);
            filterFunctionBuilders[0] = new FunctionScoreQueryBuilder.FilterFunctionBuilder(gaussDecayFunctionBuilder);
            //2.2、构建自定义的remark_score
            ScoreFunctionBuilder<FieldValueFactorFunctionBuilder> remarkScoreFunctionBuilder = new FieldValueFactorFunctionBuilder("remark_score");
            remarkScoreFunctionBuilder.setWeight(0.2f);
            filterFunctionBuilders[1] = new FunctionScoreQueryBuilder.FilterFunctionBuilder(remarkScoreFunctionBuilder);
            //2.3、构建自定义的seller_remark_score
            ScoreFunctionBuilder<FieldValueFactorFunctionBuilder> sellerRemarkScoreFunctionBuilder = new FieldValueFactorFunctionBuilder("seller_remark_score");
            sellerRemarkScoreFunctionBuilder.setWeight(0.1f);
            filterFunctionBuilders[2] = new FunctionScoreQueryBuilder.FilterFunctionBuilder(sellerRemarkScoreFunctionBuilder);

            FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders.functionScoreQuery(boolQueryBuilder, filterFunctionBuilders)
                    .boostMode(CombineFunction.SUM)
                    .scoreMode(FunctionScoreQuery.ScoreMode.SUM);

            searchSourceBuilder.query(functionScoreQueryBuilder)
                    .sort("_score", SortOrder.DESC);
        } else {
            FunctionScoreQueryBuilder.FilterFunctionBuilder[] filterFunctionBuilders = new FunctionScoreQueryBuilder.FilterFunctionBuilder[1];
            ScoreFunctionBuilder<FieldValueFactorFunctionBuilder> pricePerManFunctionBuilder = new FieldValueFactorFunctionBuilder("price_per_man");
            pricePerManFunctionBuilder.setWeight(1);
            filterFunctionBuilders[0] = new FunctionScoreQueryBuilder.FilterFunctionBuilder(pricePerManFunctionBuilder);

            FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders.functionScoreQuery(boolQueryBuilder, filterFunctionBuilders)
                    .boostMode(CombineFunction.REPLACE)
                    .scoreMode(FunctionScoreQuery.ScoreMode.SUM);
            searchSourceBuilder.query(functionScoreQueryBuilder)
                    .sort("_score", SortOrder.ASC);
        }


        //3.1 编写script_fields返回distance字段
        Map<String, Object> params = new HashMap<>(16);
        params.put("lat", longitude);
        params.put("lon", latitude);
        String lang = "expression";
        String idOrCode = "haversin(lat,lon,doc['location'].lat,doc['location'].lon)";
        Script script = new Script(ScriptType.INLINE, lang, idOrCode, params);

        //4.1 构建聚合
        TermsAggregationBuilder termsAggregationBuilder = AggregationBuilders.terms("group_by_tags").field("tags");


        //想要查询的字段
        String[] includeFields = new String[]{"id"};
        //排除的字段
        String[] excludeFields = new String[]{};

        searchSourceBuilder.scriptField("distance", script)
                .fetchSource(includeFields, excludeFields)
                .aggregation(termsAggregationBuilder);

        searchRequest.source(searchSourceBuilder);
        try {
            SearchResponse searchResponse = restHighLevelClient.search(searchRequest, RequestOptions.DEFAULT);
            if (RestStatus.OK.equals(searchResponse.status())) {
                List<ShopResponse> shopResponseList = new ArrayList<>();
                SearchHits hits = searchResponse.getHits();
                SearchHit[] searchHits = hits.getHits();
                for (SearchHit hit : searchHits) {
                    //获得ES主键=source里面的id
                    String id = hit.getId();
                    Shop shop = shopService.get(Integer.valueOf(id));
                    ShopResponse shopResponse = MapperFactory.getCopyByRefMapper().mapClass(Shop.class, ShopResponse.class)
                            .registerAndMap(shop, ShopResponse.class);
                    shopResponse.setCategoryModel(shop.getCategoryModel());
                    shopResponse.setSellerModel(shop.getSellerModel());

                    //获得距离
                    Map<String, DocumentField> fields = hit.getFields();
                    DocumentField distanceField = fields.get("distance");
                    BigDecimal disValue = new BigDecimal(distanceField.getValue().toString());
                    shopResponse.setDistance(disValue.multiply(new BigDecimal(1000).setScale(0,BigDecimal.ROUND_CEILING)).intValue());

                    shopResponseList.add(shopResponse);
                }

                List<Map<String, Object>> tagsList = new ArrayList<>();
                Aggregations aggregations = searchResponse.getAggregations();
                Terms byTagsAggregation = aggregations.get("group_by_tags");
                List<? extends Terms.Bucket> buckets = byTagsAggregation.getBuckets();
                for (Terms.Bucket bucket : buckets) {
                    Map<String, Object> tagMap = new HashMap<>(16);
                    long docCount = bucket.getDocCount();
                    String tag = (String) bucket.getKey();
                    tagMap.put("tags", tag);
                    tagMap.put("num", docCount);
                    tagsList.add(tagMap);
                }

                result.put("tags", tagsList);
                result.put("shop", shopResponseList);
            }
        } catch (IOException e) {
            log.error("{}:", e.getMessage());
        }
        return result;
    }

}
